import torch
import torch.nn.functional as F

import numpy as np
import os

import torch_geometric.utils
from torch_geometric.utils import negative_sampling


# %%
class PairwiseDistanceTask(torch.nn.Module):
    """Self-supervised task: classify the hop distance between two nodes.

    For every dataset in ``data.datasets`` the node pairs at shortest-path
    distance 1, 2 and 3 are pre-computed (and cached on disk under
    ``saved/``).  At training time a batch of such pairs is sampled together
    with "far" pairs (more than 3 hops apart) and a linear classifier on
    ``|emb_u - emb_v|`` is trained to predict the distance class.

    NOTE: cache files written by earlier versions of this class stored
    scrambled / non-exact pairs; delete ``saved/*_pairdis.pt`` to regenerate.
    """

    # Classes are: distance 1, 2, 3, >3 (labels 0, 1, 2, 3).
    num_classes = 4
    # Number of negative (>3 hops) pairs per batch; positives get
    # num_samples // (num_classes - 1) pairs per distance class.
    num_samples = 4000

    def __init__(self, data, embedding_size, device):
        """Build the predictor head and pre-compute the distance pairs.

        Args:
            data: object exposing ``datasets``; each dataset must provide
                ``name`` and ``data`` (with ``num_nodes`` and ``edge_index``).
            embedding_size: dimensionality of the node embeddings passed to
                :meth:`get_loss`.
            device: device the predictor and the labels are placed on.
        """
        super().__init__()
        self.name = 'pair-wise distance'
        self.data = data
        self.dataset_names = [dataset.name for dataset in self.data.datasets]
        self.device = device
        self.num_nodes = [dataset.data.num_nodes for dataset in self.data.datasets]
        self.predictor = torch.nn.Linear(embedding_size, self.num_classes).to(self.device)
        self.distance_node_pairs = self.calculate_distances()

    def _sample_positive_pairs(self, pairs_by_distance):
        """Sample up to ``num_samples // (num_classes - 1)`` pairs per distance.

        Returns two parallel lists (edge tensors of shape ``(2, k)`` and label
        tensors of shape ``(k,)``), one entry per non-empty distance class.
        """
        per_class = self.num_samples // (self.num_classes - 1)
        rng = np.random.default_rng()
        edges, labels = [], []
        for label in range(self.num_classes - 1):
            pairs = pairs_by_distance[label + 1]
            num_pairs = pairs.shape[1]
            if num_pairs == 0:
                # e.g. a complete graph has no pairs at distance 2 or 3.
                continue
            # Draw column indices over the stored pairs (not over the node
            # ids), capped so that sampling without replacement is possible.
            chosen = rng.choice(num_pairs, size=min(per_class, num_pairs), replace=False)
            chosen = torch.from_numpy(chosen).to(torch.long)
            edges.append(pairs[:, chosen])
            labels.append(torch.full((chosen.numel(),), label, dtype=torch.long))
        return edges, labels

    def sample(self, dataset_name):
        """Sample a labelled batch of node pairs for ``dataset_name``.

        Returns:
            ``(sampled_edges, labels)`` where ``sampled_edges`` has shape
            ``(2, N)`` and ``labels`` is a long tensor of shape ``(N,)`` on
            ``self.device`` with values in ``[0, num_classes - 1]``.

        Raises:
            ValueError: if ``dataset_name`` is not one of the known datasets.
        """
        index = self.dataset_names.index(dataset_name)
        pairs_by_distance = self.distance_node_pairs[index]
        num_nodes = self.num_nodes[index]

        # positive sampling: pairs at distance 1 .. num_classes - 1
        sampled_edges, labels = self._sample_positive_pairs(pairs_by_distance)

        # negative sampling: pairs that are NOT within num_classes - 1 hops
        within_reach = pairs_by_distance[100]
        if within_reach.numel() > 0 and int(within_reach.max()) >= num_nodes:
            print(f'found node index {int(within_reach.max())} in a graph with '
                  f'{num_nodes} nodes')
        neg_edges = negative_sampling(edge_index=within_reach,
                                      num_nodes=num_nodes,
                                      num_neg_samples=self.num_samples)
        sampled_edges.append(neg_edges)
        labels.append(torch.full((neg_edges.shape[1],), self.num_classes - 1, dtype=torch.long))

        labels = torch.cat(labels).to(self.device)
        sampled_edges = torch.cat(sampled_edges, dim=1)
        return sampled_edges, labels

    def get_loss(self, embeddings, dataset_name):
        """Return the distance-classification loss for a freshly sampled batch.

        Args:
            embeddings: node embedding matrix indexed by node id along dim 0.
            dataset_name: name of the dataset the embeddings belong to.
        """
        node_pairs, labels = self.sample(dataset_name)

        # The pair representation is the element-wise absolute difference.
        pair_features = torch.abs(embeddings[node_pairs[0]] - embeddings[node_pairs[1]])
        logits = self.predictor(pair_features)
        output = F.log_softmax(logits, dim=1)
        return F.nll_loss(output, labels)

    @staticmethod
    def _distance_pairs(adj, max_distance=3):
        """Compute node pairs at each exact hop distance from a dense adjacency.

        Args:
            adj: square dense adjacency matrix; non-zero entries are edges.
            max_distance: largest hop distance to enumerate.

        Returns:
            dict mapping ``d`` (1..max_distance) to a ``(2, N_d)`` long tensor
            of pairs whose shortest-path distance is exactly ``d``, plus key
            ``100`` holding every pair within ``max_distance`` hops including
            the diagonal (used as the exclusion set for negative sampling).
        """
        num_nodes = adj.shape[0]
        eye = torch.eye(num_nodes, dtype=torch.bool, device=adj.device)
        # Binarised one-hop matrix without self-loops (distance to self is 0).
        step = ((adj != 0) & ~eye).to(torch.float32)

        reached = eye.clone()          # pairs at distance < d
        walk = eye.to(torch.float32)   # pairs joined by a walk of length d
        pairs = {}
        for distance in range(1, max_distance + 1):
            # Re-binarise each step so walk counts cannot blow up.
            walk = ((walk @ step) > 0).to(torch.float32)
            exact = (walk > 0) & ~reached
            # nonzero() is (N, 2); transpose (not reshape) to get (2, N).
            pairs[distance] = exact.nonzero().t().contiguous()
            reached |= exact
        pairs[100] = reached.nonzero().t().contiguous()
        return pairs

    def calculate_distances(self):
        """Load or compute the distance pairs of every dataset.

        Returns:
            list with one dict per dataset (same order as
            ``self.data.datasets``); see :meth:`_distance_pairs` for the
            dict layout.
        """
        all_pairs = []
        for dataset in self.data.datasets:
            cache_path = f'saved/{dataset.name}_pairdis.pt'
            if os.path.exists(cache_path):
                print(f'loading {cache_path}')
                # NOTE(review): torch.load unpickles; only load trusted caches.
                pairs = torch.load(cache_path)
            else:
                print("Calculating pair-wise distances")
                # max_num_nodes keeps trailing isolated nodes in the matrix.
                adj = torch_geometric.utils.to_dense_adj(
                    dataset.data.edge_index,
                    max_num_nodes=dataset.data.num_nodes)[0].cpu()
                pairs = self._distance_pairs(adj, self.num_classes - 1)
                print(f'saving {cache_path}')
                os.makedirs(os.path.dirname(cache_path), exist_ok=True)
                torch.save(pairs, cache_path)
            all_pairs.append(pairs)
        return all_pairs


            # NOTE: earlier networkx-based shortest-path implementation, left commented out (unused).
            # g = torch_geometric.utils.to_networkx(dataset.data)
            # shortest_paths = dict(networkx.all_pairs_shortest_path_length(g, cutoff=self.num_classes - 1))
            # node_pairs = [[] for _ in range(self.num_classes - 1)]
            # for node_id in range(len(shortest_paths)):
            #     shortest_path_node = shortest_paths[node_id]
            #     keys, vals = np.array(list(shortest_path_node.keys())), np.array(list(shortest_path_node.values()))
            #     for path_length in range(1, self.num_classes):
            #         mask = vals == path_length
            #         node_ids = keys[mask]
            #         for _ in node_ids:
            #             node_pairs[path_length - 1].append(np.array([node_id, _]))
            # all_node_pairs.append(node_pairs)
        # return all_node_pairs
